use std::{
    fmt,
    io::{IoSlice, Read, Write},
    pin::Pin,
    task::{Context as TaskContext, Poll},
};

use pin_project_lite::pin_project;
use rama_core::{
    Service,
    error::{BoxError, ErrorContext, OpaqueError},
    extensions::ExtensionsMut,
    service::RejectService,
    stream::{HeapReader, PeekStream, StackReader},
    telemetry::tracing,
};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};

use crate::{address::Domain, tls::client::extract_sni_from_client_hello_handshake};

use super::{NoTlsRejectError, TlsPeekStream};

/// A [`Service`] router that can be used to support
/// routing of tls traffic as well as non-tls traffic.
///
/// The difference with [`TlsPeekRouter`] is that the [`SniRouter`]
/// continues to parse after the initial couple of bytes,
/// in order to learn more about the underlying traffic. This allows
/// you to route based on the SNI from the client hello,
/// among other capabilities.
///
/// By default non-tls traffic is rejected using [`RejectService`].
/// Use [`SniRouter::with_fallback`] to configure the fallback service.
pub struct SniRouter<S, F = RejectService<(), NoTlsRejectError>> {
    /// Service receiving tls traffic, wrapped in an [`SniRequest`].
    service: S,
    /// Service receiving traffic whose first bytes do not look like a tls record header.
    fallback: F,
}

impl<S> SniRouter<S> {
    /// Create a new [`SniRouter`] which forwards tls traffic to `service`.
    ///
    /// Until a fallback is configured via [`SniRouter::with_fallback`],
    /// non-tls traffic is rejected with a [`NoTlsRejectError`].
    pub fn new(service: S) -> Self {
        let fallback = RejectService::new(NoTlsRejectError);
        Self { service, fallback }
    }

    /// Swap the fallback [`Service`] of this [`SniRouter`] for `fallback`.
    ///
    /// The fallback receives the stream whenever its first bytes
    /// do not look like a tls record header.
    pub fn with_fallback<F>(self, fallback: F) -> SniRouter<S, F> {
        let Self { service, .. } = self;
        SniRouter { service, fallback }
    }
}

impl<S: Clone, F: Clone> Clone for SniRouter<S, F> {
    /// Clone both the tls service and the fallback service.
    fn clone(&self) -> Self {
        let Self { service, fallback } = self;
        Self {
            service: service.clone(),
            fallback: fallback.clone(),
        }
    }
}

impl<S: fmt::Debug, F: fmt::Debug> fmt::Debug for SniRouter<S, F> {
    /// Format as `SniRouter { service, fallback }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { service, fallback } = self;
        let mut builder = f.debug_struct("SniRouter");
        builder.field("service", service);
        builder.field("fallback", fallback);
        builder.finish()
    }
}

impl<Stream, Output, S, F> Service<Stream> for SniRouter<S, F>
where
    Stream: rama_core::stream::Stream + Unpin + ExtensionsMut,
    Output: Send + 'static,
    S: Service<SniRequest<Stream>, Output = Output, Error: Into<BoxError>>,
    F: Service<TlsPeekStream<Stream>, Output = Output, Error: Into<BoxError>>,
{
    type Output = Output;
    type Error = BoxError;

    /// Peek at the start of `stream` and route it to the tls or fallback service.
    ///
    /// One read of up to five bytes decides whether the traffic looks like a tls
    /// record header: content type `0x16`, major version `0x03`, minor version
    /// `0x00..=0x04`.
    ///
    /// - Not tls: the peeked bytes are replayed in front of the stream, which is
    ///   then handed to the fallback service.
    /// - Tls: the remainder of the first record is read in full and the SNI is
    ///   extracted from the client hello it contains. The record bytes are replayed
    ///   in front of the stream, which is passed as an [`SniRequest`] to the inner
    ///   service.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    ///
    /// - reading from the stream fails;
    /// - the stream ends before the announced record is complete;
    /// - the announced record payload exceeds 2^14 bytes;
    /// - the client hello cannot be parsed;
    /// - the selected service fails.
    async fn serve(&self, mut stream: Stream) -> Result<Self::Output, Self::Error> {
        // RFC 8446 §5.1 / RFC 5246 §6.2.1: a plaintext record MUST NOT exceed 2^14 bytes.
        const MAX_TLS_RECORD_PAYLOAD_LEN: usize = 1 << 14;

        let mut peek_buf = [0u8; TLS_HEADER_PEEK_LEN];
        // NOTE(review): a single read is used for the header peek, so a peer whose first
        // five bytes arrive over several reads is treated as non-tls. Presumably deliberate:
        // looping could stall plaintext peers that send fewer than five bytes and then wait
        // for a reply. Confirm.
        let n = stream
            .read(&mut peek_buf)
            .await
            .context("try to read tls prefix header")?;

        let is_tls = n == TLS_HEADER_PEEK_LEN && matches!(peek_buf, [0x16, 0x03, 0x00..=0x04, ..]);
        tracing::trace!("tls prefix header read (is tls: {is_tls})");

        if !is_tls {
            // Right-align the `n` bytes actually read, so that skipping the unused
            // prefix leaves exactly those bytes to be replayed to the fallback.
            let offset = TLS_HEADER_PEEK_LEN - n;
            if offset > 0 {
                tracing::trace!(
                    "move tls peek buffer cursor due to reading not enough (read: {n})"
                );
                peek_buf.copy_within(0..n, offset);
            }

            let mut peek = StackReader::new(peek_buf);
            peek.skip(offset);
            let stream = PeekStream::new(peek, stream);

            tracing::trace!("fallback to non-tls service");
            return self.fallback.serve(stream).await.map_err(Into::into);
        }

        // Bytes 3..5 of the record header hold the big-endian payload length.
        let record_len = usize::from(u16::from_be_bytes([peek_buf[3], peek_buf[4]]));
        if record_len > MAX_TLS_RECORD_PAYLOAD_LEN {
            tracing::debug!(
                record_len,
                "tls record exceeds the maximum plaintext record length"
            );
            return Err(
                OpaqueError::from_display("tls client hello record too large").into_boxed(),
            );
        }

        let mut v = vec![0u8; TLS_HEADER_PEEK_LEN + record_len];
        v[..TLS_HEADER_PEEK_LEN].copy_from_slice(&peek_buf);

        // The record may arrive split over several reads (e.g. across TCP segments),
        // so keep reading until the whole announced payload is buffered.
        let read_result = stream.read_exact(&mut v[TLS_HEADER_PEEK_LEN..]).await;
        if matches!(&read_result, Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof) {
            tracing::debug!(
                record_len,
                "eof before the full client hello handshake data was received"
            );
            return Err(
                OpaqueError::from_display("missing client hello tls handshake data").into_boxed(),
            );
        }
        read_result.context("read tls record")?;

        let sni = extract_sni_from_client_hello_handshake(&v)
            .context("parse client hello handshake bytes and extract SNI")?;

        // Replay the full record (header included) before the rest of the stream.
        let mem_reader = HeapReader::from(v);
        let peek_stream = PeekStream::new(mem_reader, stream);

        self.service
            .serve(SniRequest {
                stream: peek_stream,
                sni,
            })
            .await
            .map_err(Into::into)
    }
}

/// Length of a tls record header: content type (1 byte),
/// protocol version (2 bytes) and payload length (2 bytes).
const TLS_HEADER_PEEK_LEN: usize = 1 + 2 + 2;

/// [`PeekStream`] alias used by [`SniRouter`].
///
/// The [`HeapReader`] holds the tls record bytes already consumed while
/// extracting the SNI. These are read back before reading continues from
/// the underlying stream `S`.
pub type SniPeekStream<S> = PeekStream<HeapReader, S>;

pin_project! {
    /// A request ready for SNI routing,
    /// usually used in combination with [`SniRouter`].
    ///
    /// Reading from it yields the peeked tls record first,
    /// followed by the remaining bytes of the original stream.
    pub struct SniRequest<S> {
        // The original stream, with the already-read tls record placed in front of it.
        #[pin]
        pub stream: SniPeekStream<S>,
        // SNI parsed from the client hello; `None` when the client hello carries none.
        pub sni: Option<Domain>,
    }
}

#[warn(clippy::missing_trait_methods)]
impl<S> AsyncRead for SniRequest<S>
where
    S: AsyncRead,
{
    /// Delegate reads to the pinned inner peek stream.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.project().stream.poll_read(cx, buf)
    }
}

#[warn(clippy::missing_trait_methods)]
impl<S> AsyncBufRead for SniRequest<S>
where
    S: AsyncBufRead,
{
    /// Expose the buffered bytes of the pinned inner peek stream.
    #[inline]
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
    ) -> Poll<std::io::Result<&[u8]>> {
        self.project().stream.poll_fill_buf(cx)
    }

    /// Mark `amt` buffered bytes of the inner peek stream as consumed.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.project().stream.consume(amt);
    }
}

/// Blocking reads, each forwarded unchanged to the inner peek stream.
impl<S> Read for SniRequest<S>
where
    S: Read,
{
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        Read::read(&mut self.stream, buf)
    }

    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        Read::read_exact(&mut self.stream, buf)
    }

    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
        Read::read_to_end(&mut self.stream, buf)
    }

    fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
        Read::read_to_string(&mut self.stream, buf)
    }

    fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result<usize> {
        Read::read_vectored(&mut self.stream, bufs)
    }
}

#[warn(clippy::missing_trait_methods)]
impl<S> AsyncWrite for SniRequest<S>
where
    S: AsyncWrite,
{
    /// Forward writes to the pinned inner peek stream.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        self.project().stream.poll_write(cx, buf)
    }

    /// Flush the pinned inner peek stream.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
        self.project().stream.poll_flush(cx)
    }

    /// Shut down the pinned inner peek stream.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
        self.project().stream.poll_shutdown(cx)
    }

    /// Forward vectored writes to the pinned inner peek stream.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        self.project().stream.poll_write_vectored(cx, bufs)
    }

    /// Report whether the inner peek stream supports vectored writes.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        let inner = &self.stream;
        inner.is_write_vectored()
    }
}

/// Blocking writes, each forwarded unchanged to the inner peek stream.
impl<S> Write for SniRequest<S>
where
    S: Write,
{
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        Write::write(&mut self.stream, buf)
    }

    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        Write::flush(&mut self.stream)
    }

    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        Write::write_all(&mut self.stream, buf)
    }

    #[inline]
    fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> std::io::Result<()> {
        Write::write_fmt(&mut self.stream, args)
    }

    #[inline]
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
        Write::write_vectored(&mut self.stream, bufs)
    }
}

impl<S: Clone> Clone for SniRequest<S> {
    /// Clone the peek stream together with the extracted SNI.
    fn clone(&self) -> Self {
        let Self { stream, sni } = self;
        Self {
            stream: stream.clone(),
            sni: sni.clone(),
        }
    }
}

impl<S: fmt::Debug> fmt::Debug for SniRequest<S> {
    /// Format as `SniRequest { stream, sni }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { stream, sni } = self;
        let mut builder = f.debug_struct("SniRequest");
        builder.field("stream", stream);
        builder.field("sni", sni);
        builder.finish()
    }
}

// Tests for `SniRouter`: tls vs non-tls routing, SNI extraction,
// and replay of the peeked bytes to whichever service is selected.
#[cfg(test)]
mod test {
    use rama_core::{
        ServiceInput,
        service::{RejectError, service_fn},
    };
    use std::convert::Infallible;

    use rama_core::stream::Stream;

    use super::*;

    // A single tls record whose header (0x16 0x03 0x01, length 0x0200) announces a
    // 512-byte payload: a client hello carrying the SNI "one.one.one.one".
    const CH_ONE_ONE_ONE_ONE: &[u8] = &[
        0x16, 0x03, 0x01, 0x02, 0x00, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x02, 0x15, 0xfd, 0xe2,
        0x92, 0xc0, 0x46, 0x9f, 0x92, 0xbe, 0xd7, 0xe9, 0x1a, 0x3c, 0x50, 0x5e, 0x55, 0x49, 0x17,
        0xa6, 0xf8, 0xa5, 0xca, 0xa4, 0x6d, 0x60, 0xcc, 0xea, 0xf7, 0x25, 0xf0, 0x6e, 0x20, 0x41,
        0x20, 0x18, 0x66, 0x5c, 0xae, 0x08, 0xb0, 0x10, 0x96, 0x3c, 0xad, 0xb4, 0x13, 0xe1, 0x92,
        0xce, 0x96, 0xad, 0x9d, 0x45, 0x05, 0xb7, 0xa6, 0x4c, 0x01, 0x71, 0x08, 0x74, 0x0d, 0x1f,
        0x35, 0x00, 0x2a, 0x3a, 0x3a, 0x13, 0x01, 0x13, 0x02, 0x13, 0x03, 0xc0, 0x2c, 0xc0, 0x2b,
        0xcc, 0xa9, 0xc0, 0x30, 0xc0, 0x2f, 0xcc, 0xa8, 0xc0, 0x0a, 0xc0, 0x09, 0xc0, 0x14, 0xc0,
        0x13, 0x00, 0x9d, 0x00, 0x9c, 0x00, 0x35, 0x00, 0x2f, 0xc0, 0x08, 0xc0, 0x12, 0x00, 0x0a,
        0x01, 0x00, 0x01, 0x89, 0xda, 0xda, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x12, 0x00,
        0x00, 0x0f, 0x6f, 0x6e, 0x65, 0x2e, 0x6f, 0x6e, 0x65, 0x2e, 0x6f, 0x6e, 0x65, 0x2e, 0x6f,
        0x6e, 0x65, 0x00, 0x17, 0x00, 0x00, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0c,
        0x00, 0x0a, 0xfa, 0xfa, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0b, 0x00,
        0x02, 0x01, 0x00, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00,
        0x16, 0x00, 0x14, 0x04, 0x03, 0x08, 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x08, 0x05,
        0x05, 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, 0x12, 0x00, 0x00, 0x00, 0x33, 0x00,
        0x2b, 0x00, 0x29, 0xfa, 0xfa, 0x00, 0x01, 0x00, 0x00, 0x1d, 0x00, 0x20, 0x7c, 0xe1, 0xc6,
        0xc2, 0x01, 0x69, 0x42, 0xba, 0x2b, 0xec, 0x07, 0x2f, 0x04, 0xbd, 0xb6, 0x2a, 0x7e, 0x04,
        0x6b, 0x96, 0x98, 0x51, 0x4e, 0x80, 0xb3, 0x2a, 0x4c, 0x4f, 0x1f, 0x39, 0x82, 0x2b, 0x00,
        0x2d, 0x00, 0x02, 0x01, 0x01, 0x00, 0x2b, 0x00, 0x0b, 0x0a, 0x6a, 0x6a, 0x03, 0x04, 0x03,
        0x03, 0x03, 0x02, 0x03, 0x01, 0x00, 0x1b, 0x00, 0x03, 0x02, 0x00, 0x01, 0x3a, 0x3a, 0x00,
        0x01, 0x00, 0x00, 0x15, 0x00, 0xd3, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    // A single tls record (same 512-byte payload length) holding a client hello
    // without an SNI extension.
    const TLS_BUT_NO_SNI: &[u8] = &[
        0x16, 0x03, 0x01, 0x02, 0x00, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x28, 0x5b, 0x8f, 0x90,
        0x22, 0x2a, 0x90, 0x95, 0x89, 0xa9, 0x62, 0x1f, 0xdb, 0x68, 0xbe, 0x4c, 0x0e, 0xdf, 0xe4,
        0x76, 0x50, 0x48, 0xa5, 0x40, 0x56, 0x5f, 0x9a, 0xba, 0x19, 0x29, 0x66, 0xdd, 0x20, 0x7a,
        0x7f, 0x7e, 0xc7, 0xbd, 0xfb, 0x88, 0x07, 0xd9, 0xf5, 0x99, 0xfa, 0xf3, 0x0d, 0x37, 0x30,
        0x52, 0x4d, 0x44, 0xe4, 0x26, 0xc0, 0xd1, 0x9a, 0xcd, 0x78, 0xf6, 0x7a, 0xf1, 0x7a, 0x66,
        0xe1, 0x00, 0x3e, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0xc0, 0x2c, 0xc0, 0x30, 0x00, 0x9f,
        0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e, 0xc0, 0x24, 0xc0,
        0x28, 0x00, 0x6b, 0xc0, 0x23, 0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, 0x14, 0x00, 0x39,
        0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33, 0x00, 0x9d, 0x00, 0x9c, 0x00, 0x3d, 0x00, 0x3c, 0x00,
        0x35, 0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x01, 0x75, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00,
        0x01, 0x02, 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00,
        0x19, 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x10,
        0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e,
        0x31, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x0d,
        0x00, 0x2a, 0x00, 0x28, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08,
        0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
        0x06, 0x01, 0x03, 0x03, 0x03, 0x01, 0x03, 0x02, 0x04, 0x02, 0x05, 0x02, 0x06, 0x02, 0x00,
        0x2b, 0x00, 0x05, 0x04, 0x03, 0x04, 0x03, 0x03, 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, 0x00,
        0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xe0, 0xb9, 0xfb, 0x5a, 0xd5, 0x60,
        0x30, 0x39, 0xad, 0xfb, 0xd3, 0x94, 0xa2, 0xff, 0x08, 0x71, 0x9b, 0xcc, 0x6f, 0xbe, 0x9e,
        0xcc, 0x7b, 0xad, 0x3c, 0xd0, 0xde, 0xe8, 0x3e, 0x5d, 0xba, 0x6b, 0x00, 0x15, 0x00, 0xca,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];

    // Routing decisions:
    // - empty and short plaintext inputs go to the fallback;
    // - a tls client hello goes to the tls service with its SNI;
    // - a client hello without SNI yields `None`.
    #[tokio::test]
    async fn test_sni_router() {
        let tls_service = service_fn(async |req: SniRequest<_>| {
            let sni = req.sni.map(|sni| sni.to_string());
            Ok::<_, Infallible>(sni)
        });
        let plain_service = service_fn(async || Ok::<_, Infallible>(Some("plain".to_owned())));

        let peek_tls_svc = SniRouter::new(tls_service).with_fallback(plain_service);

        // empty input: fewer than five bytes read, so not tls
        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(b"".to_vec())))
            .await
            .unwrap();
        assert_eq!(Some("plain".to_owned()), response);

        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(
                CH_ONE_ONE_ONE_ONE.to_vec(),
            )))
            .await
            .unwrap();
        assert_eq!(Some("one.one.one.one".to_owned()), response);

        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(b"foo".to_vec())))
            .await
            .unwrap();
        assert_eq!(Some("plain".to_owned()), response);

        // at least five bytes, but the first byte is not the tls handshake content type
        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(b"foobar".to_vec())))
            .await
            .unwrap();
        assert_eq!(Some("plain".to_owned()), response);

        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(
                TLS_BUT_NO_SNI.to_vec(),
            )))
            .await
            .unwrap();
        assert_eq!(None, response);
    }

    // The tls service must observe the complete input, header included,
    // even though the router already consumed it to extract the SNI.
    #[tokio::test]
    async fn test_peek_router_read_eof() {
        async fn tls_service_fn(
            SniRequest { mut stream, sni }: SniRequest<impl Stream + Unpin>,
        ) -> Result<&'static str, BoxError> {
            let mut v = Vec::default();
            let _ = stream.read_to_end(&mut v).await?;
            assert_eq!(CH_ONE_ONE_ONE_ONE, v);
            assert!(sni.is_some());
            assert_eq!("one.one.one.one", sni.unwrap());
            Ok("ok")
        }
        let tls_service = service_fn(tls_service_fn);

        let peek_tls_svc =
            SniRouter::new(tls_service).with_fallback(
                RejectService::<&'static str, RejectError>::new(RejectError::default()),
            );

        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(
                CH_ONE_ONE_ONE_ONE.to_vec(),
            )))
            .await
            .unwrap();
        assert_eq!("ok", response);
    }

    // The fallback must observe exactly the original bytes for inputs shorter than,
    // equal to, and longer than the five-byte peek buffer.
    #[tokio::test]
    async fn test_peek_router_read_no_tls_eof() {
        let cases = ["", "foo", "abcd", "abcde", "foobarbazbananas"];
        for content in cases {
            async fn tls_service_fn() -> Result<Vec<u8>, BoxError> {
                Ok("tls".as_bytes().to_vec())
            }
            let tls_service = service_fn(tls_service_fn);

            async fn plain_service_fn(
                mut stream: impl Stream + Unpin,
            ) -> Result<Vec<u8>, BoxError> {
                let mut v = Vec::default();
                let _ = stream.read_to_end(&mut v).await?;
                Ok(v)
            }
            let plain_service = service_fn(plain_service_fn);

            let peek_tls_svc = SniRouter::new(tls_service).with_fallback(plain_service);

            let response = peek_tls_svc
                .serve(ServiceInput::new(std::io::Cursor::new(
                    content.as_bytes().to_vec(),
                )))
                .await
                .unwrap();

            assert_eq!(content.as_bytes(), &response[..]);
        }
    }

    // A client hello without SNI still reaches the tls service, with the full
    // record replayed and `sni` set to `None`.
    #[tokio::test]
    async fn test_peek_router_read_tls_no_sni_eof() {
        async fn tls_service_fn(
            SniRequest { mut stream, sni }: SniRequest<impl Stream + Unpin>,
        ) -> Result<&'static str, BoxError> {
            let mut v = Vec::default();
            let _ = stream.read_to_end(&mut v).await?;
            assert_eq!(TLS_BUT_NO_SNI, v);
            assert!(sni.is_none());
            Ok("ok")
        }
        let tls_service = service_fn(tls_service_fn);

        let peek_tls_svc =
            SniRouter::new(tls_service).with_fallback(
                RejectService::<&'static str, RejectError>::new(RejectError::default()),
            );

        let response = peek_tls_svc
            .serve(ServiceInput::new(std::io::Cursor::new(
                TLS_BUT_NO_SNI.to_vec(),
            )))
            .await
            .unwrap();
        assert_eq!("ok", response);
    }
}
